{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semantic search with FAISS (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, Evaluate, and FAISS libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so the packages are installed into the\n",
    "# environment of the running kernel rather than whichever pip is on the PATH.\n",
    "%pip install datasets evaluate transformers[sentencepiece]\n",
    "%pip install faiss-gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],\n",
       "    num_rows: 2855\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "issues_dataset = load_dataset(\"lewtun/github-issues\", split=\"train\")\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],\n",
       "    num_rows: 771\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "issues_dataset = issues_dataset.filter(\n",
    "    lambda x: (x[\"is_pull_request\"] == False and len(x[\"comments\"]) > 0)\n",
    ")\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['html_url', 'title', 'comments', 'body'],\n",
       "    num_rows: 771\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "columns = issues_dataset.column_names\n",
    "columns_to_keep = [\"title\", \"body\", \"html_url\", \"comments\"]\n",
    "# Drop only columns that exist in the dataset and are not in the keep-list.\n",
    "# A symmetric difference would also return keep-list names that are missing\n",
    "# from the dataset, which `remove_columns` cannot remove.\n",
    "columns_to_remove = [name for name in columns if name not in columns_to_keep]\n",
    "issues_dataset = issues_dataset.remove_columns(columns_to_remove)\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "issues_dataset.set_format(\"pandas\")\n",
    "df = issues_dataset[:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['the bug code locate in ：\\r\\n    if data_args.task_name is not None:\\r\\n        # Downloading and loading a dataset from the hub.\\r\\n        datasets = load_dataset(\"glue\", data_args.task_name, cache_dir=model_args.cache_dir)',\n",
       " 'Hi @jinec,\\r\\n\\r\\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\\r\\n\\r\\nNormally, it should work if you wait a little and then retry.\\r\\n\\r\\nCould you please confirm if the problem persists?',\n",
       " 'cannot connect，even by Web browser，please check that  there is some  problems。',\n",
       " 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\"comments\"][0].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "comments_df = df.explode(\"comments\", ignore_index=True)\n",
    "comments_df.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['html_url', 'title', 'comments', 'body'],\n",
       "    num_rows: 2842\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "comments_dataset = Dataset.from_pandas(comments_df)\n",
    "comments_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "comments_dataset = comments_dataset.map(\n",
    "    lambda x: {\"comment_length\": len(x[\"comments\"].split())}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],\n",
       "    num_rows: 2098\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "comments_dataset = comments_dataset.filter(lambda x: x[\"comment_length\"] > 15)\n",
    "comments_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def concatenate_text(examples):\n",
    "    \"\"\"Merge an example's title, body and comment into a single `text` field.\n",
    "\n",
    "    Args:\n",
    "        examples: A mapping with string values under the keys `title`,\n",
    "            `body` and `comments`.\n",
    "\n",
    "    Returns:\n",
    "        A dict with one `text` key holding the three strings joined by a\n",
    "        newline that is padded with one space on each side.\n",
    "    \"\"\"\n",
    "    # Newline padded with a space on each side, written as an escape sequence\n",
    "    # so the string literal stays on a single source line.\n",
    "    separator = \" \\n \"\n",
    "    parts = (examples[\"title\"], examples[\"body\"], examples[\"comments\"])\n",
    "    return {\"text\": separator.join(parts)}\n",
    "\n",
    "\n",
    "comments_dataset = comments_dataset.map(concatenate_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "model_ckpt = \"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = AutoModel.from_pretrained(model_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Prefer the GPU when one is present, but fall back to the CPU so the notebook\n",
    "# still runs on machines without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cls_pooling(model_output):\n",
    "    \"\"\"Select index 0 along the second axis of the model's last hidden state.\n",
    "\n",
    "    Args:\n",
    "        model_output: An object exposing a `last_hidden_state` attribute that\n",
    "            supports `[:, 0]` indexing (presumably shaped\n",
    "            (batch, sequence, hidden) -- confirm against the model used).\n",
    "\n",
    "    Returns:\n",
    "        `last_hidden_state[:, 0]`, i.e. one vector per entry of the first axis.\n",
    "    \"\"\"\n",
    "    hidden_states = model_output.last_hidden_state\n",
    "    return hidden_states[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(text_list):\n",
    "    \"\"\"Embed text with the notebook's `tokenizer` and `model`.\n",
    "\n",
    "    Args:\n",
    "        text_list: A string or list of strings, passed straight to the tokenizer\n",
    "            with padding and truncation enabled.\n",
    "\n",
    "    Returns:\n",
    "        The `cls_pooling` of the model output as a tensor on `device`, computed\n",
    "        without gradient tracking.\n",
    "    \"\"\"\n",
    "    encoded_input = tokenizer(\n",
    "        text_list, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    # Move every tokenizer tensor to the same device as the model.\n",
    "    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}\n",
    "    # Inference only: skip autograd bookkeeping so no graph (and its\n",
    "    # activations) is kept alive for each embedded text.\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded_input)\n",
    "    return cls_pooling(model_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 768])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding = get_embeddings(comments_dataset[\"text\"][0])\n",
    "embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_dataset = comments_dataset.map(\n",
    "    lambda x: {\"embeddings\": get_embeddings(x[\"text\"]).detach().cpu().numpy()[0]}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_dataset.add_faiss_index(column=\"embeddings\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 768])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question = \"How can I load a dataset offline?\"\n",
    "question_embedding = get_embeddings([question]).cpu().detach().numpy()\n",
    "question_embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores, samples = embeddings_dataset.get_nearest_examples(\n",
    "    \"embeddings\", question_embedding, k=5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Attach each hit's score, then order the hits from highest to lowest score.\n",
    "# NOTE(review): the FAISS index in this notebook is built without an explicit\n",
    "# metric_type; if its scores are distances (smaller = closer), this descending\n",
    "# order lists the weakest of the k hits first -- confirm.\n",
    "samples_df = (\n",
    "    pd.DataFrame.from_dict(samples)\n",
    "    .assign(scores=scores)\n",
    "    .sort_values(\"scores\", ascending=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"\"\"\n",
       "COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.\n",
       "\n",
       "@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?\n",
       "SCORE: 25.505046844482422\n",
       "TITLE: Discussion using datasets in offline mode\n",
       "URL: https://github.com/huggingface/datasets/issues/824\n",
       "==================================================\n",
       "\n",
       "COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)\n",
       "You can now use them offline\n",
       "\\`\\`\\`python\n",
       "datasets = load_dataset(\"text\", data_files=data_files)\n",
       "\\`\\`\\`\n",
       "\n",
       "We'll do a new release soon\n",
       "SCORE: 24.555509567260742\n",
       "TITLE: Discussion using datasets in offline mode\n",
       "URL: https://github.com/huggingface/datasets/issues/824\n",
       "==================================================\n",
       "\n",
       "COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.\n",
       "\n",
       "Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)\n",
       "\n",
       "I already note the \"freeze\" modules option, to prevent local modules updates. It would be a cool feature.\n",
       "\n",
       "----------\n",
       "\n",
       "> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?\n",
       "\n",
       "Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.\n",
       "For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do\n",
       "\\`\\`\\`python\n",
       "load_dataset(\"./my_dataset\")\n",
       "\\`\\`\\`\n",
       "and the dataset script will generate your dataset once and for all.\n",
       "\n",
       "----------\n",
       "\n",
       "About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.\n",
       "cf #1724\n",
       "SCORE: 24.14896583557129\n",
       "TITLE: Discussion using datasets in offline mode\n",
       "URL: https://github.com/huggingface/datasets/issues/824\n",
       "==================================================\n",
       "\n",
       "COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine\n",
       ">\n",
       "> 1. (online machine)\n",
       ">\n",
       "> ```\n",
       ">\n",
       "> import datasets\n",
       ">\n",
       "> data = datasets.load_dataset(...)\n",
       ">\n",
       "> data.save_to_disk(/YOUR/DATASET/DIR)\n",
       ">\n",
       "> ```\n",
       ">\n",
       "> 2. copy the dir from online to the offline machine\n",
       ">\n",
       "> 3. (offline machine)\n",
       ">\n",
       "> ```\n",
       ">\n",
       "> import datasets\n",
       ">\n",
       "> data = datasets.load_from_disk(/SAVED/DATA/DIR)\n",
       ">\n",
       "> ```\n",
       ">\n",
       ">\n",
       ">\n",
       "> HTH.\n",
       "\n",
       "\n",
       "SCORE: 22.893993377685547\n",
       "TITLE: Discussion using datasets in offline mode\n",
       "URL: https://github.com/huggingface/datasets/issues/824\n",
       "==================================================\n",
       "\n",
       "COMMENT: here is my way to load a dataset offline, but it **requires** an online machine\n",
       "1. (online machine)\n",
       "\\`\\`\\`\n",
       "import datasets\n",
       "data = datasets.load_dataset(...)\n",
       "data.save_to_disk(/YOUR/DATASET/DIR)\n",
       "\\`\\`\\`\n",
       "2. copy the dir from online to the offline machine\n",
       "3. (offline machine)\n",
       "\\`\\`\\`\n",
       "import datasets\n",
       "data = datasets.load_from_disk(/SAVED/DATA/DIR)\n",
       "\\`\\`\\`\n",
       "\n",
       "HTH.\n",
       "SCORE: 22.406635284423828\n",
       "TITLE: Discussion using datasets in offline mode\n",
       "URL: https://github.com/huggingface/datasets/issues/824\n",
       "==================================================\n",
       "\"\"\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for _, row in samples_df.iterrows():\n",
    "    print(f\"COMMENT: {row.comments}\")\n",
    "    print(f\"SCORE: {row.scores}\")\n",
    "    print(f\"TITLE: {row.title}\")\n",
    "    print(f\"URL: {row.html_url}\")\n",
    "    print(\"=\" * 50)\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Semantic search with FAISS (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
